{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "basic_classification.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhoQ0WE77laV",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_ckMIh7O7s6D",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vasWnqRgy1H4",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title MIT License\n",
        "#\n",
        "# Copyright (c) 2017 François Chollet\n",
        "#\n",
        "# Permission is hereby granted, free of charge, to any person obtaining a\n",
        "# copy of this software and associated documentation files (the \"Software\"),\n",
        "# to deal in the Software without restriction, including without limitation\n",
        "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
        "# and/or sell copies of the Software, and to permit persons to whom the\n",
        "# Software is furnished to do so, subject to the following conditions:\n",
        "#\n",
        "# The above copyright notice and this permission notice shall be included in\n",
        "# all copies or substantial portions of the Software.\n",
        "#\n",
        "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
        "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
        "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
        "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
        "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
        "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
        "# DEALINGS IN THE SOFTWARE."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM",
        "colab_type": "text"
      },
      "source": [
        "# 训练您的第一个神经网络: 基本分类"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2",
        "colab_type": "text"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/r1/tutorials/keras/basic_classification.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/r1/tutorials/keras/basic_classification.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbVhjPpzn6BM",
        "colab_type": "text"
      },
      "source": [
        "本指南训练了一个神经网络模型，来对服装图像进行分类，例如运动鞋和衬衫。如果您不了解所有细节也不需要担心，这是一个对完整TensorFlow项目的简要概述，相关的细节会在需要时进行解释\n",
        "\n",
        "本指南使用[tf.keras](https://www.tensorflow.org/r1/guide/keras)，这是一个用于在TensorFlow中构建和训练模型的高级API。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dzLKpmZICaWN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "# 导入TensorFlow和tf.keras\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "# 导入辅助库\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "print(tf.__version__)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yR0EdgrLCaWR",
        "colab_type": "text"
      },
      "source": [
        "## 导入Fashion MNIST数据集"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DLdCchMdCaWQ",
        "colab_type": "text"
      },
      "source": [
        "本指南使用[Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) 数据集，其中包含了10个类别中共70,000张灰度图像。图像包含了低分辨率（28 x 28像素）的单个服装物品，如下所示:\n",
        "\n",
        "<table>\n",
        "  <tr><td>\n",
        "    <img src=\"https://tensorflow.org/images/fashion-mnist-sprite.png\"\n",
        "         alt=\"Fashion MNIST sprite\"  width=\"600\">\n",
        "  </td></tr>\n",
        "  <tr><td align=\"center\">\n",
        "    <b>Figure 1.</b> <a href=\"https://github.com/zalandoresearch/fashion-mnist\">Fashion-MNIST samples</a> (by Zalando, MIT License).<br/>&nbsp;\n",
        "  </td></tr>\n",
        "</table>\n",
        "\n",
        "Fashion MNIST 旨在替代传统的[MNIST](http://yann.lecun.com/exdb/mnist/)数据集 — 它经常被作为机器学习在计算机视觉方向的\"Hello, World\"。MNIST数据集包含手写数字（0,1,2等）的图像，其格式与我们在此处使用的服装相同。\n",
        "\n",
        "本指南使用Fashion MNIST进行多样化，因为它比普通的MNIST更具挑战性。两个数据集都相对较小，用于验证算法是否按预期工作。它们是测试和调试代码的良好起点。\n",
        "\n",
        "我们将使用60,000张图像来训练网络和10,000张图像来评估网络模型学习图像分类任务的准确程度。您可以直接从TensorFlow使用Fashion MNIST，只需导入并加载数据"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7MqDQO0KCaWS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fashion_mnist = keras.datasets.fashion_mnist\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t9FDsUlxCaWW",
        "colab_type": "text"
      },
      "source": [
        "加载数据集并返回四个NumPy数组:\n",
        "\n",
        "* `train_images`和`train_labels`数组是*训练集*—这是模型用来学习的数据。\n",
        "* 模型通过*测试集*进行测试, 即`test_images`与 `test_labels`两个数组。\n",
        "\n",
        "图像是28x28 NumPy数组，像素值介于0到255之间。*labels*是一个整数数组，数值介于0到9之间。这对应了图像所代表的服装的*类别*:\n",
        "\n",
        "<table>\n",
        "  <tr>\n",
        "    <th>标签</th>\n",
        "    <th>类别</th>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>0</td>\n",
        "    <td>T-shirt/top</td>\n",
        "  </tr>\n",
        "  <tr>\n",
        "    <td>1</td>\n",
        "    <td>Trouser</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>2</td>\n",
        "    <td>Pullover</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>3</td>\n",
        "    <td>Dress</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>4</td>\n",
        "    <td>Coat</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>5</td>\n",
        "    <td>Sandal</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>6</td>\n",
        "    <td>Shirt</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>7</td>\n",
        "    <td>Sneaker</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>8</td>\n",
        "    <td>Bag</td>\n",
        "  </tr>\n",
        "    <tr>\n",
        "    <td>9</td>\n",
        "    <td>Ankle boot</td>\n",
        "  </tr>\n",
        "</table>\n",
        "\n",
        "每个图像都映射到一个标签。由于*类别名称*不包含在数据集中,因此把他们存储在这里以便在绘制图像时使用:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IjnLH5S2CaWx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
        "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Brm0b_KACaWX",
        "colab_type": "text"
      },
      "source": [
        "## 探索数据\n",
        "\n",
        "让我们在训练模型之前探索数据集的格式。以下显示训练集中有60,000个图像，每个图像表示为28 x 28像素:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zW5k_xz1CaWX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_images.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cIAcvQqMCaWf",
        "colab_type": "text"
      },
      "source": [
        "同样，训练集中有60,000个标签:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TRFYHB2mCaWb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "len(train_labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YSlYxFuRCaWk",
        "colab_type": "text"
      },
      "source": [
        "每个标签都是0到9之间的整数:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XKnCTHz4CaWg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TMPI88iZpO2T",
        "colab_type": "text"
      },
      "source": [
        "测试集中有10,000个图像。 同样，每个图像表示为28×28像素:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2KFnYlcwCaWl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "test_images.shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rd0A0Iu0CaWq",
        "colab_type": "text"
      },
      "source": [
        "测试集包含10,000个图像标签:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iJmPr5-ACaWn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "len(test_labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ES6uQoLKCaWr",
        "colab_type": "text"
      },
      "source": [
        "## 数据预处理\n",
        "\n",
        "在训练网络之前必须对数据进行预处理。 如果您检查训练集中的第一个图像，您将看到像素值落在0到255的范围内:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m4VEw8Ud9Quh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.figure()\n",
        "plt.imshow(train_images[0])\n",
        "plt.colorbar()\n",
        "plt.grid(False)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wz7l27Lz9S1P",
        "colab_type": "text"
      },
      "source": [
        "在馈送到神经网络模型之前，我们将这些值缩放到0到1的范围。为此，我们将像素值值除以255。重要的是，对*训练集*和*测试集*要以相同的方式进行预处理:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bW5WzIPlCaWv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_images = train_images / 255.0\n",
        "\n",
        "test_images = test_images / 255.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ee638AlnCaWz",
        "colab_type": "text"
      },
      "source": [
        "显示*训练集*中的前25个图像，并在每个图像下方显示类名。验证数据格式是否正确，我们是否已准备好构建和训练网络。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oZTImqg_CaW1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plt.figure(figsize=(10,10))\n",
        "for i in range(25):\n",
        "    plt.subplot(5,5,i+1)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    plt.imshow(train_images[i], cmap=plt.cm.binary)\n",
        "    plt.xlabel(class_names[train_labels[i]])\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "59veuiEZCaW4",
        "colab_type": "text"
      },
      "source": [
        "## 构建模型\n",
        "\n",
        "构建神经网络需要配置模型的层，然后编译模型。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gxg1XGm0eOBy",
        "colab_type": "text"
      },
      "source": [
        "### 设置网络层\n",
        "\n",
        "一个神经网络最基本的组成部分便是*网络层*。网络层从提供给他们的数据中提取表示，并期望这些表示对当前的问题更加有意义\n",
        "\n",
        "大多数深度学习是由串连在一起的网络层所组成。大多数网络层，例如`tf.keras.layers.Dense`，具有在训练期间学习的参数。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9ODch-OFCaW4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = keras.Sequential([\n",
        "    keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    keras.layers.Dense(128, activation=tf.nn.relu),\n",
        "    keras.layers.Dense(10, activation=tf.nn.softmax)\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gut8A_7rCaW6",
        "colab_type": "text"
      },
      "source": [
        "网络中的第一层, `tf.keras.layers.Flatten`, 将图像格式从一个二维数组(包含着28x28个像素)转换成为一个包含着28 * 28 = 784个像素的一维数组。可以将这个网络层视为它将图像中未堆叠的像素排列在一起。这个网络层没有需要学习的参数;它仅仅对数据进行格式化。\n",
        "\n",
        "在像素被展平之后，网络由一个包含有两个`tf.keras.layers.Dense`网络层的序列组成。他们被称作稠密链接层或全连接层。 第一个`Dense`网络层包含有128个节点(或被称为神经元)。第二个(也是最后一个)网络层是一个包含10个节点的*softmax*层—它将返回包含10个概率分数的数组，总和为1。每个节点包含一个分数，表示当前图像属于10个类别之一的概率。\n",
        "\n",
        "### 编译模型\n",
        "\n",
        "在模型准备好进行训练之前，它还需要一些配置。这些是在模型的*编译(compile)*步骤中添加的:\n",
        "\n",
        "* *损失函数* —这可以衡量模型在培训过程中的准确程度。 我们希望将此函数最小化以\"驱使\"模型朝正确的方向拟合。\n",
        "* *优化器* —这就是模型根据它看到的数据及其损失函数进行更新的方式。\n",
        "* *评价方式* —用于监控训练和测试步骤。以下示例使用*准确率(accuracy)*，即正确分类的图像的分数。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lhan11blCaW7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model.compile(optimizer='adam',\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qKF6uW-BCaW-",
        "colab_type": "text"
      },
      "source": [
        "## 训练模型\n",
        "\n",
        "训练神经网络模型需要以下步骤:\n",
        "\n",
        "1. 将训练数据提供给模型 - 在本案例中，他们是`train_images`和`train_labels`数组。\n",
        "2. 模型学习如何将图像与其标签关联\n",
        "3. 我们使用模型对测试集进行预测, 在本案例中为`test_images`数组。我们验证预测结果是否匹配`test_labels`数组中保存的标签。\n",
        "\n",
        "通过调用`model.fit`方法来训练模型 — 模型对训练数据进行\"拟合\"。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xvwvpA64CaW_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model.fit(train_images, train_labels, epochs=5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3ZVOhugCaXA",
        "colab_type": "text"
      },
      "source": [
        "随着模型训练，将显示损失和准确率等指标。该模型在训练数据上达到约0.88(或88％)的准确度。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEw4bZgGCaXB",
        "colab_type": "text"
      },
      "source": [
        "## 评估准确率\n",
        "\n",
        "接下来，比较模型在测试数据集上的执行情况:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VflXLEeECaXC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)\n",
        "\n",
        "print('Test accuracy:', test_acc)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yWfgsmVXCaXG",
        "colab_type": "text"
      },
      "source": [
        "事实证明，测试数据集的准确性略低于训练数据集的准确性。训练精度和测试精度之间的差距是*过拟合*的一个例子。过拟合是指机器学习模型在新数据上的表现比在训练数据上表现更差。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xsoS7CPDCaXH",
        "colab_type": "text"
      },
      "source": [
        "## 进行预测\n",
        "\n",
        "通过训练模型，我们可以使用它来预测某些图像。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gl91RPhdCaXI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "predictions = model.predict(test_images)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9Kk1voUCaXJ",
        "colab_type": "text"
      },
      "source": [
        "在此，模型已经预测了测试集中每个图像的标签。我们来看看第一个预测:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3DmJEUinCaXK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "predictions[0]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-hw1hgeSCaXN",
        "colab_type": "text"
      },
      "source": [
        "预测是10个数字的数组。这些描述了模型的\"信心\"，即图像对应于10种不同服装中的每一种。我们可以看到哪个标签具有最高的置信度值："
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qsqenuPnCaXO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "np.argmax(predictions[0])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E51yS7iCCaXO",
        "colab_type": "text"
      },
      "source": [
        "因此，模型最有信心的是这个图像是ankle boot，或者 `class_names[9]`。 我们可以检查测试标签，看看这是否正确:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sd7Pgsu6CaXP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "test_labels[0]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ygh2yYC972ne",
        "colab_type": "text"
      },
      "source": [
        "我们可以用图表来查看全部10个类别"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DvYmmrpIy6Y1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def plot_image(i, predictions_array, true_label, img):\n",
        "  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]\n",
        "  plt.grid(False)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  \n",
        "  plt.imshow(img, cmap=plt.cm.binary)\n",
        "  \n",
        "  predicted_label = np.argmax(predictions_array)\n",
        "  if predicted_label == true_label:\n",
        "    color = 'blue'\n",
        "  else:\n",
        "    color = 'red'\n",
        "  \n",
        "  plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
        "                                100*np.max(predictions_array),\n",
        "                                class_names[true_label]),\n",
        "                                color=color)\n",
        "\n",
        "def plot_value_array(i, predictions_array, true_label):\n",
        "  predictions_array, true_label = predictions_array[i], true_label[i]\n",
        "  plt.grid(False)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n",
        "  plt.ylim([0, 1])\n",
        "  predicted_label = np.argmax(predictions_array)\n",
        "  \n",
        "  thisplot[predicted_label].set_color('red')\n",
        "  thisplot[true_label].set_color('blue')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d4Ov9OFDMmOD",
        "colab_type": "text"
      },
      "source": [
        "让我们看看第0个图像，预测和预测数组。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HV5jw-5HwSmO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "i = 0\n",
        "plt.figure(figsize=(6,3))\n",
        "plt.subplot(1,2,1)\n",
        "plot_image(i, predictions, test_labels, test_images)\n",
        "plt.subplot(1,2,2)\n",
        "plot_value_array(i, predictions,  test_labels)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ko-uzOufSCSe",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "i = 12\n",
        "plt.figure(figsize=(6,3))\n",
        "plt.subplot(1,2,1)\n",
        "plot_image(i, predictions, test_labels, test_images)\n",
        "plt.subplot(1,2,2)\n",
        "plot_value_array(i, predictions,  test_labels)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kgdvGD52CaXR",
        "colab_type": "text"
      },
      "source": [
        "让我们绘制几个图像及其预测结果。正确的预测标签是蓝色的，不正确的预测标签是红色的。该数字给出了预测标签的百分比(满分100)。请注意，即使非常自信，也可能出错。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hQlnbqaw2Qu_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 绘制前X个测试图像，预测标签和真实标签\n",
        "# 以蓝色显示正确的预测，红色显示不正确的预测\n",
        "num_rows = 5\n",
        "num_cols = 3\n",
        "num_images = num_rows*num_cols\n",
        "plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n",
        "for i in range(num_images):\n",
        "  plt.subplot(num_rows, 2*num_cols, 2*i+1)\n",
        "  plot_image(i, predictions, test_labels, test_images)\n",
        "  plt.subplot(num_rows, 2*num_cols, 2*i+2)\n",
        "  plot_value_array(i, predictions, test_labels)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R32zteKHCaXT",
        "colab_type": "text"
      },
      "source": [
        "最后，使用训练的模型对单个图像进行预测。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yRJ7JU7JCaXT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 从测试数据集中获取图像\n",
        "img = test_images[0]\n",
        "\n",
        "print(img.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vz3bVp21CaXV",
        "colab_type": "text"
      },
      "source": [
        "`tf.keras`模型经过优化，可以一次性对*批量*,或者一个集合的数据进行预测。因此，即使我们使用单个图像，我们也需要将其添加到列表中:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lDFh5yF_CaXW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 将图像添加到批次中，即使它是唯一的成员。\n",
        "img = (np.expand_dims(img,0))\n",
        "\n",
        "print(img.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQ5wLTkcCaXY",
        "colab_type": "text"
      },
      "source": [
        "现在来预测图像:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o_rzNSdrCaXY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "predictions_single = model.predict(img)\n",
        "\n",
        "print(predictions_single)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6Ai-cpLjO-3A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "plot_value_array(0, predictions_single, test_labels)\n",
        "plt.xticks(range(10), class_names, rotation=45)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cU1Y2OAMCaXb",
        "colab_type": "text"
      },
      "source": [
        "`model.predict`返回一个包含列表的列表，每个图像对应一个列表的数据。获取批次中我们(仅有的)图像的预测:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2tRmdq_8CaXb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "prediction_result = np.argmax(predictions_single[0])\n",
        "print(prediction_result)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YFc2HbEVCaXd",
        "colab_type": "text"
      },
      "source": [
        "而且，和之前一样，模型预测标签为9。"
      ]
    }
  ]
}